{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpath = 'qa_corpus.csv'\n",
    "import pandas as pd\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "content = pd.read_csv(fpath, encoding='utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import jieba\n",
    "import re\n",
    "\n",
    "\n",
    "import os\n",
    "LTP_DATA_DIR = 'D:/ProgramData/nlp_package/ltp_v34'  # ltp模型目录的路径\n",
    "cws_model_path = os.path.join(LTP_DATA_DIR, 'cws.model')  # 分词模型路径，模型名称为`cws.model`\n",
    "\n",
    "from pyltp import Segmentor\n",
    "segmentor = Segmentor()  # 初始化实例\n",
    "segmentor.load_with_lexicon(cws_model_path, 'lexicon_seg.txt') # 加载外部词典文件路径\n",
    "\n",
    "\n",
    "def token(string):\n",
    "    return re.findall(r'[\\d|\\w]+', string)\n",
    "\n",
    "def cut(string): \n",
    "    return ' '.join(segmentor.segment(string))\n",
    "\n",
    "def filter_text(content):\n",
    "    q_content = content['question'].tolist()\n",
    "    q_content = [token(str(n)) for n in q_content]\n",
    "    q_content = [' '.join(n) for n in q_content]\n",
    "    q_content = [cut(n) for n in q_content]\n",
    "    return q_content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_content = filter_text(content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "stopwords = []\n",
    "with open('chinese_stopwords.txt', 'r', encoding='utf-8') as f:\n",
    "    for line in f.readlines():\n",
    "        if len(line.strip()) < 2:\n",
    "            stopwords.append(line.strip())\n",
    "\n",
    "with open('哈工大停用词表.txt', 'r', encoding='utf-8') as f:\n",
    "    for line in f.readlines():\n",
    "        if len(line.strip()) < 2:\n",
    "            stopwords.append(line.strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorized = TfidfVectorizer(max_features=12000, tokenizer=lambda x: x.split(), stop_words=stopwords)\n",
    "X = vectorized.fit_transform(q_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.sparse as sp\n",
    "\n",
    "\n",
    "X_array = X.toarray()\n",
    "X_array = sp.csr_matrix(X_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "(array([ 5278,  9413,  9499, 10423], dtype=int64),)"
     },
     "metadata": {},
     "execution_count": 8
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "np.nonzero(X[100].toarray()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine\n",
    "\n",
    "def distance(v1, v2): return cosine(v1, v2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import and_\n",
    "from functools import reduce\n",
    "\n",
    "\n",
    "pos_model_path = os.path.join(LTP_DATA_DIR, 'pos.model')\n",
    "\n",
    "from pyltp import Postagger\n",
    "postagger = Postagger() # 初始化实例\n",
    "postagger.load_with_lexicon(pos_model_path, 'lexicon.txt')  # 加载模型\n",
    "\n",
    "and_pos_set = {'n', 'v', 'm', 'nh', 'ni', 'nl', 'ns', 'nt', 'ws'}\n",
    "\n",
    "\n",
    "def token(string):\n",
    "    return re.findall(r'[\\d|\\w]+', string)\n",
    "\n",
    "def cut(string): \n",
    "    return segmentor.segment(string)\n",
    "\n",
    "def filter_text_single(string):\n",
    "    q_content = token(string.lower())\n",
    "    print(q_content)\n",
    "    q_content = ' '.join(q_content)\n",
    "    q_content = cut(q_content)\n",
    "    return q_content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_2_id = vectorized.vocabulary_\n",
    "id_2_word = {d: w for w, d in word_2_id.items()}\n",
    "\n",
    "inverse_idx = X_array.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_connect_doc(query):\n",
    "    \"\"\"\"\"\"\n",
    "    words = filter_text_single(query)\n",
    "    postags = postagger.postag(words)\n",
    "\n",
    "    to_and = []\n",
    "    for i, postag in enumerate(postags):\n",
    "        if postag in and_pos_set:\n",
    "            to_and.append(words[i])\n",
    "    print(to_and)\n",
    "    \n",
    "    query_vec = vectorized.transform([' '.join(words)]).toarray()\n",
    "    try:\n",
    "        candidates_ids = [word_2_id[w] for w in to_and]\n",
    "    except KeyError:\n",
    "        pass\n",
    "\n",
    "    documents_ids = [\n",
    "         set(np.nonzero(inverse_idx[_id].toarray()[0])[0]) for _id in candidates_ids\n",
    "    ]\n",
    "\n",
    "    merged_documents = reduce(and_, documents_ids)\n",
    "\n",
    "    return merged_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "['代发工资']\n['代发', '工资']\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "{1,\n 24425,\n 24821,\n 24822,\n 24823,\n 24824,\n 25316,\n 25317,\n 25320,\n 25321,\n 25748,\n 25749,\n 25799,\n 26738,\n 26739}"
     },
     "metadata": {},
     "execution_count": 25
    }
   ],
   "source": [
    "search_connect_doc(content.question[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "'代发工资客户个人信用消费贷款用途的使用范围'"
     },
     "metadata": {},
     "execution_count": 26
    }
   ],
   "source": [
    "content.question[24425]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python36864bit023718609e434315a7782a7404fb6072",
   "display_name": "Python 3.6.8 64-bit"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}